# import
import numpy as np
import pandas as pd
from scipy.sparse import find
from numba import njit

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from util.models import *
from util.misc import *

## mat file
from scipy.io import loadmat


def preprocess_df(df):
    """Build the table of distinct lotteries found in the experimental data.

    Parameters
    ----------
    df : pandas dataframe
        Dataframe containing the experimental data.
        It must include the columns 'z1', 'z2', and 'p1'; other columns
        are ignored. df itself is not modified.

    Returns
    -------
    df_lottery : pandas dataframe
        columns : z1, z2, p1, expected_value, id_outcome
        One row per distinct combination of (z1, z2, p1).
        Rows are sorted in the lexicographical order of (z1, z2, p1) and
        the index is reset to 0, ..., n-1.
        expected_value is z1*p1 + z2*(1 - p1).
        id_outcome is an integer id for each outcome pair (z1, z2),
        numbered from 0 in the sorted order of the pairs.

    TO DO
    -----
    Check if the column 'lottery' is unnecessary or not
    Run the notebook with the dataframe that does not have 'lottery' column
    """
    lottery_cols = ['z1', 'z2', 'p1']
    outcome_cols = ['z1', 'z2']

    df_lottery = (
        df[lottery_cols]
        .drop_duplicates()
        .sort_values(lottery_cols)
        .reset_index(drop=True)
    )
    df_lottery['expected_value'] = (
        df_lottery['z1'] * df_lottery['p1']
        + df_lottery['z2'] * (1.0 - df_lottery['p1'])
    )

    # put id's for each outcome
    # reset_index returns a new frame, so its result has to be kept; it also
    # gives us an independent frame that is safe to add a column to.
    df_outcome = df_lottery[outcome_cols].drop_duplicates().reset_index(drop=True)
    df_outcome['id_outcome'] = np.arange(df_outcome.shape[0])

    # attach the outcome id to every lottery that shares the (z1, z2) pair
    return pd.merge(df_lottery, df_outcome, on=outcome_cols, how='left')


def generate_FOSD_matrix(df_lottery):
    """Generate a matrix used to check FOSD condition later

    Parameters
    ----------
    df_lottery : pandas dataframe
        See preprocess_df. Only the columns 'z1', 'z2' and 'p1' are read.
        Rows are addressed by position, so the index of the dataframe
        does not matter.

    Returns
    -------
    fosd : ndarray(int, ndim=2)
        Square array with one row/column per lottery, containing 0, 1 elements.
        fosd[i, j] = 1 iff lottery i FOSD lottery j, and the supports of i and j are overlapped.

    Notes
    -----
        Lottery i is treated as dominating lottery j when z1, z2 and p1 of i
        are all >= those of j and at least one of them is strictly larger.
        The supports are treated as overlapping when z2 of i < z1 of j.
        If the supports of lottery i and j are not overlapped, we don't need to check FOSD relation of them.
        For the moment, I assume that the lottery does not include losses, and z1 > z2.
    """
    z1 = df_lottery['z1'].values
    z2 = df_lottery['z2'].values
    p1 = df_lottery['p1'].values

    # Broadcast every lottery i (column vector) against every lottery j
    # (row vector) instead of looping over all pairs in Python.
    z1_i, z1_j = z1[:, None], z1[None, :]
    z2_i, z2_j = z2[:, None], z2[None, :]
    p1_i, p1_j = p1[:, None], p1[None, :]

    # check FOSD relation: weakly better everywhere, strictly better somewhere
    weakly_better = (z1_i >= z1_j) & (z2_i >= z2_j) & (p1_i >= p1_j)
    strictly_better = (z1_i > z1_j) | (z2_i > z2_j) | (p1_i > p1_j)

    # check if supports are overlapping
    ## this is for speeding up the code -- reduce the number of lotteries to be checked
    overlap = z2_i < z1_j

    # FOSD restriction and overlap
    return (weakly_better & strictly_better & overlap).astype(int)


def generate_fake_ce(df_lottery):
    """Randomly sample certainty equivalents from the admissible set for each lottery

    Parameters
    ----------
    df_lottery : pandas dataframe
        See preprocess_df. The columns 'z1', 'z2' and 'id_outcome' are read;
        df_lottery itself is not modified.

    Returns
    -------
    ce : ndarray(float, ndim=1)
        Generated random certainty equivalents, one per row of df_lottery.
        Each CE is drawn as z2 + (z1 - z2)*u with u from np.random.rand,
        i.e. from the global NumPy random state.
        CEs are sorted so that lotteries with the same outcome at least do not violate FOSD restriction.

    Notes
    -----
        Within each id_outcome group the drawn CEs are reassigned in
        increasing order following the row order. This assumes that rows
        of the same outcome are ordered by increasing p1, as in the output
        of preprocess_df.
    """
    z1 = df_lottery['z1'].values.astype(float)
    z2 = df_lottery['z2'].values.astype(float)
    outcome_ids = df_lottery['id_outcome'].values

    # one uniform draw per lottery, mapped onto the interval between z2 and z1
    rvs = np.random.rand(z1.shape[0])
    ce = z2 + (z1 - z2)*rvs

    # For lotteries that share the same outcome, reorder their CEs so that they don't violate FOSD
    # (done on the float array directly, so no float values are written
    # into an integer dataframe column)
    for id_outcome in np.unique(outcome_ids):
        same_outcome = (outcome_ids == id_outcome)
        ce[same_outcome] = np.sort(ce[same_outcome])
    return ce


def pass_FOSD_criterion(ce, row_fosd, col_fosd):
    """Tell whether the CEs respect every listed FOSD relation
    ce must be indexed in the same lottery order as (row_fosd, col_fosd).

    Parameters
    ----------
    ce : ndarray(float, ndim=1)
        See generate_fake_ce

    row_fosd : ndarray(int, ndim=1)
    col_fosd : ndarray(int, ndim=1)
        Paired index arrays, walked in lockstep: one entry per FOSD relation to check.
        lottery row_fosd[k] FOSD lottery col_fosd[k]

    Returns
    -------
    retval : bool
        False as soon as some pair has ce[row_fosd[k]] <= ce[col_fosd[k]],
        True otherwise (including when there is no pair to check)
    """
    # any() stops at the first violating pair, like an explicit break would
    violated = any(
        ce[dominant] <= ce[dominated]
        for dominant, dominated in zip(row_fosd, col_fosd)
    )
    return not violated


# @njit
def closest_CPT(z1, z2, p1, ce, param_grid):
    """Search the grid for the CPT parameter whose predictions are closest to ce

    Parameters
    ----------
    z1, z2, p1: ndarray(float, ndim=1)
    ce : ndarray(float, ndim=1)
        See generate_fake_ce
    param_grid : ndarray(float)
        Candidate parameters; iterating over param_grid yields one
        candidate at a time, and each is passed as is to pred_CPT

    Returns
    -------
    model_error : float
        Smallest mean squared difference between ce and the prediction;
        np.inf if no candidate gives an error below np.inf
    best_param : ndarray(float)
        The element of param_grid that attains model_error (not copied);
        zeros shaped like param_grid[0] if no candidate was accepted

    Notes
    -----
        pred_CPT is not defined in this file; presumably it comes from
        util.models -- to be confirmed.
    """
    # state before any candidate has been accepted
    best_param = np.zeros_like(param_grid[0])
    model_error = np.inf

    for candidate in param_grid:
        residual = ce - pred_CPT(z1, z2, p1, candidate)
        mse = (residual**2).mean()
        # strict comparison: the first candidate wins ties
        if mse < model_error:
            model_error, best_param = mse, candidate

    return model_error, best_param


# @njit
def closest_DA(z1, z2, p1, ce, param_grid):
    """Search the grid for the DA parameter whose predictions are closest to ce

    Parameters
    ----------
    z1, z2, p1: ndarray(float, ndim=1)
    ce : ndarray(float, ndim=1)
        See generate_fake_ce
    param_grid : ndarray(float)
        Candidate parameters; iterating over param_grid yields one
        candidate at a time, and each is passed as is to pred_DA

    Returns
    -------
    model_error : float
        Smallest mean squared difference between ce and the prediction;
        np.inf if no candidate gives an error below np.inf
    best_param : ndarray(float)
        The element of param_grid that attains model_error (not copied);
        zeros shaped like param_grid[0] if no candidate was accepted

    Notes
    -----
        pred_DA is not defined in this file; presumably it comes from
        util.models -- to be confirmed.
    """
    # state before any candidate has been accepted
    best_param = np.zeros_like(param_grid[0])
    model_error = np.inf

    for candidate in param_grid:
        residual = ce - pred_DA(z1, z2, p1, candidate)
        mse = (residual**2).mean()
        # strict comparison: the first candidate wins ties
        if mse < model_error:
            model_error, best_param = mse, candidate

    return model_error, best_param


# @njit
def compute_restrictiveness(z1, z2, p1, expval, ce_data, param_grid, closest=closest_CPT):
    """Compute the restrictiveness

    Parameters
    ----------
    z1, z2, p1: ndarray(float, ndim=1)
    expval : ndarray(float, ndim=1)
        Expected value of each lottery, used as the baseline prediction
    ce_data : ndarray(float, ndim=2)
        Array containing the generated certainty equivalents data
        each column contains one set of generated CEs
        the number of columns = sample size
        See also generate_fake_ce
    param_grid : ndarray(float)
        Array containing all possible parameter combinations,
        passed unchanged to closest
    closest : function, optional
        Called as closest(z1, z2, p1, ce, param_grid) and must return
        (model_error, best_param).
        See closest_CPT, by default closest_CPT

    Returns
    -------
    restrictiveness : float
        Mean model error over the samples divided by the mean baseline
        error, where the baseline error of a sample is the mean squared
        difference between its CEs and expval
    stderr : float
        Standard error of restrictiveness, from the first-order
        (delta-method) approximation for a ratio of two sample means

    Notes
    -----
        The variances and the covariance of the per-sample errors all come
        from one np.cov call, so they share the same normalisation.
        Mixing np.var (divides by n) with np.cov (divides by n - 1) can
        make the variance estimate negative, and hence stderr NaN, when
        the model errors are proportional to the baseline errors.
        With a single sample np.cov is undefined and stderr is NaN.
    """
    sample_size = ce_data.shape[1]
    base_errors = np.zeros(sample_size)
    model_errors = np.zeros(sample_size)

    for sample in range(sample_size):
        ce = ce_data[:, sample]
        base_errors[sample] = ((ce - expval)**2).mean()
        # only the error is needed here, the best parameter is discarded
        model_errors[sample], _ = closest(z1, z2, p1, ce, param_grid)

    base_error = base_errors.mean()
    model_error = model_errors.mean()
    restrictiveness = model_error/base_error

    # standard error
    error_cov = np.cov(model_errors, base_errors)
    model_var = error_cov[0, 0]
    base_var = error_cov[1, 1]
    covar = error_cov[0, 1]
    var_r = (model_var - 2*restrictiveness*covar + (restrictiveness**2)*base_var)/(base_error**2)
    # rounding can leave a tiny negative value when the true variance is 0;
    # NaN is passed through unchanged by max()
    stderr = np.sqrt(max(var_r, 0.0)/sample_size)

    return restrictiveness, stderr


if __name__ == '__main__':
    # Script entry point: compute the restrictiveness of the default model
    # (closest_CPT) on the lotteries of the Bruhin et al. (2010) data file.
    # numpy, pandas and loadmat are already imported at the top of this file.

    # load dataset (the path is resolved relative to this script)
    base = os.path.dirname(os.path.abspath(__file__))
    file_name = os.path.normpath(os.path.join(base, '../data/Bruhin_et_al_2010.mat'))
    matdata = loadmat(file_name)

    # one column per variable of the mat file; keys containing '__' are
    # skipped (presumably loadmat's metadata entries -- to be confirmed)
    df_full = pd.DataFrame({
        key: value.flatten() for key, value in matdata.items() if '__' not in key
    })

    ## drop lotteries over losses
    df = df_full[df_full['z2'] >= 0]

    df_lottery = preprocess_df(df)
    z1 = df_lottery['z1'].values
    z2 = df_lottery['z2'].values
    p1 = df_lottery['p1'].values
    expval = df_lottery['expected_value'].values

    # generated CEs: one column per sample (see compute_restrictiveness).
    # NOTE: unlike the mat file, this path is relative to the current
    # working directory. df_ce must be read before its values are used.
    df_ce = pd.read_csv('generated_ce.csv')
    ce_data = df_ce.values

    # set search space for parameters
    alpha_grid_size = 0.01
    gamma_grid_size = 0.01
    delta_grid_size = 0.1
    alpha_ubd = 1
    gamma_ubd = 1
    delta_ubd = 30

    # each grid runs from one step above 0 up to (about) its upper bound
    alpha_grid = np.arange(alpha_grid_size, alpha_ubd+alpha_grid_size, alpha_grid_size)
    gamma_grid = np.arange(gamma_grid_size, gamma_ubd+gamma_grid_size, gamma_grid_size)
    delta_grid = np.arange(delta_grid_size, delta_ubd+delta_grid_size, delta_grid_size)
    # cartesian_product is not defined in this file; presumably from util.misc
    param_grid = cartesian_product(alpha_grid, gamma_grid, delta_grid)

    restrictiveness, stderr = compute_restrictiveness(z1, z2, p1, expval, ce_data, param_grid)
    print('restrictiveness: {}'.format(restrictiveness))
    print('stderr: {}'.format(stderr))
